[PyTorch][torch.compile] Remove process group from quantizers#3104
[PyTorch][torch.compile] Remove process group from quantizers#3104pggPL wants to merge 6 commits into
Conversation
|
/te-ci pytorch L1 |
Greptile SummaryThis PR refactors amax reduction process group handling out of persistent quantizer state and into per-call arguments, making the quantizer objects free of distributed state and enabling
Confidence Score: 5/5Safe to merge; the refactoring correctly moves process groups out of persistent quantizer state and all existing amax-reduction code paths are covered by explicit per-call setup. The logic change is well-contained: each module entry point explicitly sets or clears the amax reduction group immediately before quantization, and the
Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Caller
participant Module as Linear / LNLinear / LNMlp
participant Helper as set_quantizer_amax_reduction_group
participant Q as Quantizer (source)
participant QImpl as quantize_impl
participant Result as QuantizedTensor._quantizer (copy)
Caller->>Module: "forward(inp, tp_group=...)"
Module->>Helper: set_quantizer_amax_reduction_group(input_q, tp_group if SP+col else None)
Helper->>Q: "q.with_amax_reduction = True / False"
Module->>Q: q.quantize(tensor)
Q->>QImpl: quantize_impl(tensor) — new QuantizedTensor
QImpl-->>Q: result (with _quantizer copy)
Q->>Result: if copy.with_amax_reduction — clear it
Note over Q: SOURCE q still has with_amax_reduction=True
Caller->>Module: "backward(grad, tp_group=...)"
Module->>Helper: set_quantizer_amax_reduction_group(grad_out_q, tp_group if SP+row else None)
Helper->>Q: "q.with_amax_reduction = True / False"
Note over Module: _linear_backward also resets input_q explicitly
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant Caller
participant Module as Linear / LNLinear / LNMlp
participant Helper as set_quantizer_amax_reduction_group
participant Q as Quantizer (source)
participant QImpl as quantize_impl
participant Result as QuantizedTensor._quantizer (copy)
Caller->>Module: "forward(inp, tp_group=...)"
Module->>Helper: set_quantizer_amax_reduction_group(input_q, tp_group if SP+col else None)
Helper->>Q: "q.with_amax_reduction = True / False"
Module->>Q: q.quantize(tensor)
Q->>QImpl: quantize_impl(tensor) — new QuantizedTensor
QImpl-->>Q: result (with _quantizer copy)
Q->>Result: if copy.with_amax_reduction — clear it
Note over Q: SOURCE q still has with_amax_reduction=True
Caller->>Module: "backward(grad, tp_group=...)"
Module->>Helper: set_quantizer_amax_reduction_group(grad_out_q, tp_group if SP+row else None)
Helper->>Q: "q.with_amax_reduction = True / False"
Note over Module: _linear_backward also resets input_q explicitly
Reviews (8): Last reviewed commit: "Carry amax reduction group on the Quanti..." | Re-trigger Greptile |
| """Quantize tensor""" | ||
| return self.quantize(tensor) | ||
| if amax_reduction_group is None: | ||
| return self.quantize(tensor) | ||
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) |
There was a problem hiding this comment.
The
None guard here is redundant: self.quantize(tensor) and self.quantize(tensor, amax_reduction_group=None) are identical because quantize defaults the argument to None. The branch just adds noise.
| """Quantize tensor""" | |
| return self.quantize(tensor) | |
| if amax_reduction_group is None: | |
| return self.quantize(tensor) | |
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) | |
| """Quantize tensor""" | |
| return self.quantize(tensor, amax_reduction_group=amax_reduction_group) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| @property | ||
| def rht_matrix(self) -> torch.Tensor: | ||
| """RHT matrix (fetched from the process-global cache, not stored per quantizer).""" | ||
| return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) |
There was a problem hiding this comment.
Deserialization break for old pickled
NVFP4Quantizer instances
rht_matrix is now a property that reads self._with_random_sign_mask, but _with_random_sign_mask is a new field that did not exist in pickled state produced before this change. When Python's default __setstate__ (i.e., self.__dict__.update(state)) loads an old pickle, _with_random_sign_mask is absent, so any access to the rht_matrix property raises AttributeError. A __setstate__ that infers _with_random_sign_mask from the old stored rht_matrix (or supplies a safe default) would preserve backward compatibility for serialized quantizers.
|
Blocked by FSDP bug, refactor in progress. I plan to store .amax_reduction_group in QuantizedTensor. |
There was a problem hiding this comment.
This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).
Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:
class Float8CurrentScalingQuantizer(Quantizer):
_communicator_cache = {}
@property
def amax_reduction_group(self):
if self._amax_reduction_group_id is None:
return None
return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]
@property.setter
def amax_reduction_group(self, comm):
if comm is None:
self._amax_reduction_group_id = None
self._amax_reduction_group_id = id(comm)
Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = commI'm not sure how this would interact with checkpointing though.
| dst: QuantizedTensor, | ||
| *, | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument |
There was a problem hiding this comment.
I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.
|
/te-ci pytorch L1 |
…stants; fix SP memory leak; test suite hook-up Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
…t_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
e9097d6 to
948cd6d
Compare
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module quantizer or on a tensor's quantizer. No C++ changes. - TP sequence parallel: the group is set on the input/grad-output quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp, ops basic_linear), replacing the setup-time _customize_quantizers wiring. - FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the in-place re-quant (update_quantized / _set_data). - quantize() strips the group off the output tensor's quantizer so it never persists on any tensor's quantizer (breaks flatten/pickle otherwise). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
b8c1bec to
6c9b986
Compare
Description
This makes adding torch.compile support much easier.
Move amax reduction process group handling out of quantizer state and pass it per quantization call instead. This avoids storing process groups inside quantizers while keeping deprecated stored-group fallback behavior for compatibility.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
amax_reduction_groupthroughquantize/module call paths instead of storing it on quantizers.Checklist: